Skip to content

Conversation

@Lancern
Copy link
Member

@Lancern Lancern commented Oct 22, 2024

This PR adds several new builders for llvm.intr.assume that build the operation with additional operand bundles.

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2024

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir-sme

Author: Sirui Mu (Lancern)

Changes

This PR adds two intrinsic operations, namely llvm.intr.assume.align and llvm.intr.assume.separate_storage. Module translation translates both operations to intrinsic calls to @<!-- -->llvm.assume, with different assume operand bundles.

This PR also adds a new builder to llvm.intr.assume to make it easier to build the operation with assume operand bundles.


Full diff: https://github.com/llvm/llvm-project/pull/113317.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td (+1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h (+1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td (+35-1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td (+35-17)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+1-1)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (+3-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+18)
  • (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+39-10)
  • (added) mlir/test/Dialect/LLVMIR/assume.mlir (+20)
  • (modified) mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index e81db32bcaad03..6ea3c9f2e1c7ba 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -68,6 +68,7 @@ class ArmSME_IntrOp<string mnemonic,
           /*list<int> overloadedOperands=*/overloadedOperands,
           /*list<Trait> traits=*/traits,
           /*int numResults=*/numResults,
+          /*bit enableMlirBuilder=*/1,
           /*bit requiresAccessGroup=*/0,
           /*bit requiresAliasAnalysis=*/0,
           /*bit requiresFastmath=*/0,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index d236cae0d80882..cf721f936cc932 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -33,6 +33,7 @@
 #include "mlir/Support/ThreadLocalCache.h"
 #include "llvm/ADT/PointerEmbeddedInt.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 845c88b1be7750..8bbd2b9053e160 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -450,12 +450,46 @@ def LLVM_AssumeOp
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$cond)>
+    OpBuilder<(ins "Value":$cond)>,
+    OpBuilder<(ins "Value":$cond,
+                   "ArrayRef<llvm::OperandBundleDefT<Value>>":$opBundles)>
   ];
 
   let hasVerifier = 1;
 }
 
+class LLVM_AssumeOpBase<string mnem, list<int> opBundleOperandPositions,
+                        string opBundleTag>
+    : LLVM_IntrOp<"assume." # mnem, /*overloadedResults=*/[],
+                  /*overloadedOperands=*/[], /*traits=*/[],
+                  /*numResults=*/0, /*enumName=*/"assume",
+                  /*enableMlirBuilder=*/0, /*requiresAccessGroup=*/0,
+                  /*requiresAliasAnalysis=*/0, /*requiresFastmath=*/0,
+                  /*requiresOpBundles=*/0, /*immArgPositions=*/[],
+                  /*immArgAttrNames=*/[],
+                  /*opBundleOperandPositions=*/[opBundleOperandPositions],
+                  /*opBundleTags=*/[opBundleTag]> {
+  dag args = (ins I1:$cond);
+}
+
+def LLVM_AssumeAlignOp : LLVM_AssumeOpBase<"align", [1, 2], "align"> {
+  let arguments = !con(args, (ins LLVM_AnyPointer:$ptr, AnyInteger:$align));
+
+  let assemblyFormat = [{
+    $cond `,` $ptr `,` $align attr-dict `:` functional-type(operands, results)
+  }];
+}
+
+def LLVM_AssumeSeparateStorageOp
+    : LLVM_AssumeOpBase<"separate_storage", [1, 2], "separate_storage"> {
+  let arguments = !con(
+    args, (ins LLVM_AnyPointer:$ptr1, LLVM_AnyPointer:$ptr2));
+
+  let assemblyFormat = [{
+    $cond `,` $ptr1 `,` $ptr2 attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 def LLVM_SSACopyOp : LLVM_OneResultIntrOp<"ssa.copy", [], [0],
                                             [Pure, SameOperandsAndResultType]> {
   let arguments = (ins AnyType:$operand);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index a38dafa4d9cf34..9f6acbcd3c5104 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -290,10 +290,12 @@ class LLVM_MemAccessOpBase<string mnemonic, list<Trait> traits = []> :
 class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
                       list<int> overloadedResults, list<int> overloadedOperands,
                       list<Trait> traits, int numResults,
-                      bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0,
-                      bit requiresFastmath = 0, bit requiresOpBundles = 0,
-                      list<int> immArgPositions = [],
-                      list<string> immArgAttrNames = []>
+                      bit enableMlirBuilder = 1, bit requiresAccessGroup = 0,
+                      bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
+                      bit requiresOpBundles = 0, list<int> immArgPositions = [],
+                      list<string> immArgAttrNames = [],
+                      list<list<int>> opBundleOperandPositions = [],
+                      list<string> opBundleTags = []>
     : LLVM_OpBase<dialect, opName, !listconcat(
         !if(!gt(requiresAccessGroup, 0),
             [DeclareOpInterfaceMethods<AccessGroupOpInterface>], []),
@@ -325,11 +327,18 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
   string immArgPositionsCpp = "{" # !interleave(immArgPositions, ", ") # "}";
   string immArgAttrNamesCpp = "{" # !interleave(!foreach(name, immArgAttrNames,
     "StringLiteral(\"" # name # "\")"), ", ") # "}";
+  string opBundleOperandPositionsCpp = "{" # !interleave(
+    !foreach(positions, opBundleOperandPositions,
+      "ArrayRef<unsigned>{" # !interleave(positions, ", ") # "}"
+    ), ", ") # "}";
+  string opBundleTagsCpp = "{" # !interleave(!foreach(tag, opBundleTags,
+    "StringLiteral(\"" # tag # "\")"), ", ") # "}";
   string baseLlvmBuilder = [{
     auto *inst = LLVM::detail::createIntrinsicCall(
       builder, moduleTranslation, &opInst, llvm::Intrinsic::}] # !interleave([
         enumName, "" # numResults, overloadedResultsCpp, overloadedOperandsCpp,
-        immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{);
+        immArgPositionsCpp, immArgAttrNamesCpp, opBundleOperandPositionsCpp,
+        opBundleTagsCpp], ",") # [{);
     (void) inst;
     }];
   string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", "");
@@ -357,9 +366,10 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
       $_location, resultTypes, mlirOperands, mlirAttrs);
     }];
   string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
-  let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0),
+  let mlirBuilder = !if(enableMlirBuilder,
+    baseMlirBuilder # !if(!gt(requiresFastmath, 0),
       "moduleImport.setFastmathFlagsAttr(inst, op);", "")
-    # baseMlirBuilderCoda;
+    # baseMlirBuilderCoda, "");
 
   // Code for handling a `range` attribute that holds the constant range of the
   // intrinsic's result (if one is specified at the call site). This is intended
@@ -387,16 +397,20 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
 // the intrinsic into the LLVM dialect and prefixes its name with "intr.".
 class LLVM_IntrOp<string mnem, list<int> overloadedResults,
                   list<int> overloadedOperands, list<Trait> traits,
-                  int numResults, bit requiresAccessGroup = 0,
+                  int numResults, string enumName = "",
+                  bit enableMlirBuilder = 1, bit requiresAccessGroup = 0,
                   bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
-                  bit requiresOpBundles = 0,
-                  list<int> immArgPositions = [],
-                  list<string> immArgAttrNames = []>
-    : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
+                  bit requiresOpBundles = 0, list<int> immArgPositions = [],
+                  list<string> immArgAttrNames = [],
+                  list<list<int>> opBundleOperandPositions = [],
+                  list<string> opBundleTags = []>
+    : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem,
+                      !if(!empty(enumName), !subst(".", "_", mnem), enumName),
                       overloadedResults, overloadedOperands, traits,
-                      numResults, requiresAccessGroup, requiresAliasAnalysis,
-                      requiresFastmath, requiresOpBundles, immArgPositions,
-                      immArgAttrNames>;
+                      numResults, enableMlirBuilder, requiresAccessGroup,
+                      requiresAliasAnalysis, requiresFastmath,
+                      requiresOpBundles, immArgPositions, immArgAttrNames,
+                      opBundleOperandPositions, opBundleTags>;
 
 // Base class for LLVM intrinsic operations returning no results. Places the
 // intrinsic into the LLVM dialect and prefixes its name with "intr.".
@@ -418,11 +432,14 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
                             bit requiresAliasAnalysis = 0,
                             bit requiresOpBundles = 0,
                             list<int> immArgPositions = [],
-                            list<string> immArgAttrNames = []>
+                            list<string> immArgAttrNames = [],
+                            list<list<int>> opBundleOperandPositions = [],
+                            list<string> opBundleTags = []>
     : LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0,
+                  /*enumName=*/"", /*enableMlirBuilder=*/1,
                   requiresAccessGroup, requiresAliasAnalysis,
                   /*requiresFastMath=*/0, requiresOpBundles, immArgPositions,
-                  immArgAttrNames>;
+                  immArgAttrNames, opBundleOperandPositions, opBundleTags>;
 
 // Base class for LLVM intrinsic operations returning one result. Places the
 // intrinsic into the LLVM dialect and prefixes its name with "intr.". This is
@@ -437,6 +454,7 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
                            list<int> immArgPositions = [],
                            list<string> immArgAttrNames = []>
     : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
+                  /*enumName=*/"", /*enableMlirBuilder=*/1,
                   /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
                   requiresFastmath, /*requiresOpBundles=*/0, immArgPositions,
                   immArgAttrNames>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 3695708439d91f..7c204c99525ef7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -97,7 +97,7 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
   list<string> immArgAttrNames = []> :
   LLVM_IntrOpBase<ROCDL_Dialect,  mnemonic,
     "amdgcn_" # !subst(".", "_", mnemonic), overloadedResults,
-    overloadedOperands, traits, numResults, requiresAccessGroup,
+    overloadedOperands, traits, numResults, 1, requiresAccessGroup,
     requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>;
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index ffeeeae57ae952..0c7e22f8c65596 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -434,8 +434,9 @@ llvm::CallInst *createIntrinsicCall(
     llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
     Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
     ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
-    ArrayRef<unsigned> immArgPositions,
-    ArrayRef<StringLiteral> immArgAttrNames);
+    ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
+    ArrayRef<ArrayRef<unsigned>> opBundleOperandPositions,
+    ArrayRef<StringLiteral> opBundleTags);
 
 } // namespace detail
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index cc73878a64ff67..f558cf23411ed6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3441,6 +3441,24 @@ void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
                /*op_bundle_tags=*/{});
 }
 
+void LLVM::AssumeOp::build(
+    OpBuilder &builder, OperationState &state, mlir::Value cond,
+    ArrayRef<llvm::OperandBundleDefT<mlir::Value>> opBundles) {
+  SmallVector<mlir::ValueRange> opBundleOperands;
+  SmallVector<mlir::Attribute> opBundleTags;
+  opBundleOperands.reserve(opBundles.size());
+  opBundleTags.reserve(opBundles.size());
+
+  for (const llvm::OperandBundleDefT<mlir::Value> &bundle : opBundles) {
+    opBundleOperands.emplace_back(bundle.inputs());
+    opBundleTags.push_back(
+        StringAttr::get(builder.getContext(), bundle.getTag()));
+  }
+
+  auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags);
+  return build(builder, state, cond, opBundleOperands, opBundleTagsAttr);
+}
+
 LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ceb8ba3b33818b..de493891ed7e4b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -849,30 +849,34 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
     llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
     Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
     ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
-    ArrayRef<unsigned> immArgPositions,
-    ArrayRef<StringLiteral> immArgAttrNames) {
+    ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
+    ArrayRef<ArrayRef<unsigned>> opBundleOperandPositions,
+    ArrayRef<StringLiteral> opBundleTags) {
   assert(immArgPositions.size() == immArgAttrNames.size() &&
          "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
          "length");
+  assert(opBundleOperandPositions.size() == opBundleTags.size() &&
+         "operand bundles and tags do not match");
 
   SmallVector<llvm::OperandBundleDef> opBundles;
-  size_t numOpBundleOperands = 0;
+
+  size_t numVariadicOpBundleOperands = 0;
   auto opBundleSizesAttr = cast_if_present<DenseI32ArrayAttr>(
       intrOp->getAttr(LLVMDialect::getOpBundleSizesAttrName()));
   auto opBundleTagsAttr = cast_if_present<ArrayAttr>(
       intrOp->getAttr(LLVMDialect::getOpBundleTagsAttrName()));
-
   if (opBundleSizesAttr && opBundleTagsAttr) {
     ArrayRef<int> opBundleSizes = opBundleSizesAttr.asArrayRef();
     assert(opBundleSizes.size() == opBundleTagsAttr.size() &&
            "operand bundles and tags do not match");
 
-    numOpBundleOperands =
+    numVariadicOpBundleOperands =
         std::accumulate(opBundleSizes.begin(), opBundleSizes.end(), size_t(0));
-    assert(numOpBundleOperands <= intrOp->getNumOperands() &&
+    assert(numVariadicOpBundleOperands <= intrOp->getNumOperands() &&
            "operand bundle operands is more than the number of operands");
 
-    ValueRange operands = intrOp->getOperands().take_back(numOpBundleOperands);
+    ValueRange operands =
+        intrOp->getOperands().take_back(numVariadicOpBundleOperands);
     size_t nextOperandIdx = 0;
     opBundles.reserve(opBundleSizesAttr.size());
 
@@ -887,9 +891,29 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   }
 
   // Map operands and attributes to LLVM values.
-  auto opOperands = intrOp->getOperands().drop_back(numOpBundleOperands);
+  auto opOperands =
+      intrOp->getOperands().drop_back(numVariadicOpBundleOperands);
   auto operands = moduleTranslation.lookupValues(opOperands);
-  SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size());
+
+  // Map operand bundle operands to LLVM operand bundles.
+  DenseSet<unsigned> opBundleOperandPositionsSet;
+  for (auto [positions, tag] :
+       llvm::zip(opBundleOperandPositions, opBundleTags)) {
+    opBundleOperandPositionsSet.insert(positions.begin(), positions.end());
+
+    SmallVector<llvm::Value *> bundleArgs;
+    bundleArgs.reserve(positions.size());
+    for (unsigned idx : positions) {
+      assert(idx < operands.size() &&
+             "op bundle operand index is out of range");
+      bundleArgs.push_back(operands[idx]);
+    }
+
+    opBundles.emplace_back(tag.str(), std::move(bundleArgs));
+  }
+
+  SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size() -
+                                  opBundleOperandPositionsSet.size());
   for (auto [immArgPos, immArgName] :
        llvm::zip(immArgPositions, immArgAttrNames)) {
     auto attr = llvm::cast<TypedAttr>(intrOp->getAttr(immArgName));
@@ -900,8 +924,11 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   }
   unsigned opArg = 0;
   for (auto &arg : args) {
-    if (!arg)
+    if (!arg) {
+      while (opBundleOperandPositionsSet.contains(opArg))
+        ++opArg;
       arg = operands[opArg++];
+    }
   }
 
   // Resolve overloaded intrinsic declaration.
@@ -923,6 +950,8 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   llvm::Function *llvmIntr = llvm::Intrinsic::getOrInsertDeclaration(
       module, intrinsic, overloadedTypes);
 
+  llvm::outs() << "debug: createIntrinsicCall: num args = " << args.size()
+               << ", num op bundles = " << opBundles.size() << "\n";
   return builder.CreateCall(llvmIntr, args, opBundles);
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/assume.mlir b/mlir/test/Dialect/LLVMIR/assume.mlir
new file mode 100644
index 00000000000000..4cf43b4828010f
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/assume.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @assume_align
+// CHECK-SAME: (ptr %[[ARG:.+]])
+llvm.func @assume_align(%arg0: !llvm.ptr) {
+  %0 = llvm.mlir.constant(1 : i1) : i1
+  %1 = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: call void @llvm.assume(i1 true) [ "align"(ptr %[[ARG]], i32 8) ]
+  llvm.intr.assume.align %0, %arg0, %1 : (i1, !llvm.ptr, i32) -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: @assume_separate_storage
+// CHECK-SAME: (ptr %[[ARG0:.+]], ptr %[[ARG1:.+]])
+llvm.func @assume_separate_storage(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  %0 = llvm.mlir.constant(1 : i1) : i1
+  // CHECK: call void @llvm.assume(i1 true) [ "separate_storage"(ptr %[[ARG0]], ptr %[[ARG1]]) ]
+  llvm.intr.assume.separate_storage %0, %arg0, %arg1 : (i1, !llvm.ptr, !llvm.ptr) -> ()
+  llvm.return
+}
diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
index 411a98a48bfb28..6fc3e989074937 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
@@ -237,7 +237,7 @@ static bool emitIntrinsic(const Record &record, llvm::raw_ostream &os) {
   printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os);
   os << ", ";
   printBracketedRange(traits, os);
-  os << ", " << intr.getNumResults() << ", "
+  os << ", " << intr.getNumResults() << ", \"\", 1, "
      << (requiresAccessGroup ? "1" : "0") << ", "
      << (requiresAliasAnalysis ? "1" : "0") << ">, Arguments<(ins"
      << (operands.empty() ? "" : " ");

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2024

@llvm/pr-subscribers-mlir-core

Author: Sirui Mu (Lancern)

Changes

This PR adds two intrinsic operations, namely llvm.intr.assume.align and llvm.intr.assume.separate_storage. Module translation translates both operations to intrinsic calls to @<!-- -->llvm.assume, with different assume operand bundles.

This PR also adds a new builder to llvm.intr.assume to make it easier to build the operation with assume operand bundles.


Full diff: https://github.com/llvm/llvm-project/pull/113317.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td (+1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h (+1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td (+35-1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td (+35-17)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+1-1)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (+3-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+18)
  • (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+39-10)
  • (added) mlir/test/Dialect/LLVMIR/assume.mlir (+20)
  • (modified) mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index e81db32bcaad03..6ea3c9f2e1c7ba 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -68,6 +68,7 @@ class ArmSME_IntrOp<string mnemonic,
           /*list<int> overloadedOperands=*/overloadedOperands,
           /*list<Trait> traits=*/traits,
           /*int numResults=*/numResults,
+          /*bit enableMlirBuilder=*/1,
           /*bit requiresAccessGroup=*/0,
           /*bit requiresAliasAnalysis=*/0,
           /*bit requiresFastmath=*/0,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index d236cae0d80882..cf721f936cc932 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -33,6 +33,7 @@
 #include "mlir/Support/ThreadLocalCache.h"
 #include "llvm/ADT/PointerEmbeddedInt.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 845c88b1be7750..8bbd2b9053e160 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -450,12 +450,46 @@ def LLVM_AssumeOp
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$cond)>
+    OpBuilder<(ins "Value":$cond)>,
+    OpBuilder<(ins "Value":$cond,
+                   "ArrayRef<llvm::OperandBundleDefT<Value>>":$opBundles)>
   ];
 
   let hasVerifier = 1;
 }
 
+class LLVM_AssumeOpBase<string mnem, list<int> opBundleOperandPositions,
+                        string opBundleTag>
+    : LLVM_IntrOp<"assume." # mnem, /*overloadedResults=*/[],
+                  /*overloadedOperands=*/[], /*traits=*/[],
+                  /*numResults=*/0, /*enumName=*/"assume",
+                  /*enableMlirBuilder=*/0, /*requiresAccessGroup=*/0,
+                  /*requiresAliasAnalysis=*/0, /*requiresFastmath=*/0,
+                  /*requiresOpBundles=*/0, /*immArgPositions=*/[],
+                  /*immArgAttrNames=*/[],
+                  /*opBundleOperandPositions=*/[opBundleOperandPositions],
+                  /*opBundleTags=*/[opBundleTag]> {
+  dag args = (ins I1:$cond);
+}
+
+def LLVM_AssumeAlignOp : LLVM_AssumeOpBase<"align", [1, 2], "align"> {
+  let arguments = !con(args, (ins LLVM_AnyPointer:$ptr, AnyInteger:$align));
+
+  let assemblyFormat = [{
+    $cond `,` $ptr `,` $align attr-dict `:` functional-type(operands, results)
+  }];
+}
+
+def LLVM_AssumeSeparateStorageOp
+    : LLVM_AssumeOpBase<"separate_storage", [1, 2], "separate_storage"> {
+  let arguments = !con(
+    args, (ins LLVM_AnyPointer:$ptr1, LLVM_AnyPointer:$ptr2));
+
+  let assemblyFormat = [{
+    $cond `,` $ptr1 `,` $ptr2 attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 def LLVM_SSACopyOp : LLVM_OneResultIntrOp<"ssa.copy", [], [0],
                                             [Pure, SameOperandsAndResultType]> {
   let arguments = (ins AnyType:$operand);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index a38dafa4d9cf34..9f6acbcd3c5104 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -290,10 +290,12 @@ class LLVM_MemAccessOpBase<string mnemonic, list<Trait> traits = []> :
 class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
                       list<int> overloadedResults, list<int> overloadedOperands,
                       list<Trait> traits, int numResults,
-                      bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0,
-                      bit requiresFastmath = 0, bit requiresOpBundles = 0,
-                      list<int> immArgPositions = [],
-                      list<string> immArgAttrNames = []>
+                      bit enableMlirBuilder = 1, bit requiresAccessGroup = 0,
+                      bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
+                      bit requiresOpBundles = 0, list<int> immArgPositions = [],
+                      list<string> immArgAttrNames = [],
+                      list<list<int>> opBundleOperandPositions = [],
+                      list<string> opBundleTags = []>
     : LLVM_OpBase<dialect, opName, !listconcat(
         !if(!gt(requiresAccessGroup, 0),
             [DeclareOpInterfaceMethods<AccessGroupOpInterface>], []),
@@ -325,11 +327,18 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
   string immArgPositionsCpp = "{" # !interleave(immArgPositions, ", ") # "}";
   string immArgAttrNamesCpp = "{" # !interleave(!foreach(name, immArgAttrNames,
     "StringLiteral(\"" # name # "\")"), ", ") # "}";
+  string opBundleOperandPositionsCpp = "{" # !interleave(
+    !foreach(positions, opBundleOperandPositions,
+      "ArrayRef<unsigned>{" # !interleave(positions, ", ") # "}"
+    ), ", ") # "}";
+  string opBundleTagsCpp = "{" # !interleave(!foreach(tag, opBundleTags,
+    "StringLiteral(\"" # tag # "\")"), ", ") # "}";
   string baseLlvmBuilder = [{
     auto *inst = LLVM::detail::createIntrinsicCall(
       builder, moduleTranslation, &opInst, llvm::Intrinsic::}] # !interleave([
         enumName, "" # numResults, overloadedResultsCpp, overloadedOperandsCpp,
-        immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{);
+        immArgPositionsCpp, immArgAttrNamesCpp, opBundleOperandPositionsCpp,
+        opBundleTagsCpp], ",") # [{);
     (void) inst;
     }];
   string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", "");
@@ -357,9 +366,10 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
       $_location, resultTypes, mlirOperands, mlirAttrs);
     }];
   string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
-  let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0),
+  let mlirBuilder = !if(enableMlirBuilder,
+    baseMlirBuilder # !if(!gt(requiresFastmath, 0),
       "moduleImport.setFastmathFlagsAttr(inst, op);", "")
-    # baseMlirBuilderCoda;
+    # baseMlirBuilderCoda, "");
 
   // Code for handling a `range` attribute that holds the constant range of the
   // intrinsic's result (if one is specified at the call site). This is intended
@@ -387,16 +397,20 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
 // the intrinsic into the LLVM dialect and prefixes its name with "intr.".
 class LLVM_IntrOp<string mnem, list<int> overloadedResults,
                   list<int> overloadedOperands, list<Trait> traits,
-                  int numResults, bit requiresAccessGroup = 0,
+                  int numResults, string enumName = "",
+                  bit enableMlirBuilder = 1, bit requiresAccessGroup = 0,
                   bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
-                  bit requiresOpBundles = 0,
-                  list<int> immArgPositions = [],
-                  list<string> immArgAttrNames = []>
-    : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
+                  bit requiresOpBundles = 0, list<int> immArgPositions = [],
+                  list<string> immArgAttrNames = [],
+                  list<list<int>> opBundleOperandPositions = [],
+                  list<string> opBundleTags = []>
+    : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem,
+                      !if(!empty(enumName), !subst(".", "_", mnem), enumName),
                       overloadedResults, overloadedOperands, traits,
-                      numResults, requiresAccessGroup, requiresAliasAnalysis,
-                      requiresFastmath, requiresOpBundles, immArgPositions,
-                      immArgAttrNames>;
+                      numResults, enableMlirBuilder, requiresAccessGroup,
+                      requiresAliasAnalysis, requiresFastmath,
+                      requiresOpBundles, immArgPositions, immArgAttrNames,
+                      opBundleOperandPositions, opBundleTags>;
 
 // Base class for LLVM intrinsic operations returning no results. Places the
 // intrinsic into the LLVM dialect and prefixes its name with "intr.".
@@ -418,11 +432,14 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
                             bit requiresAliasAnalysis = 0,
                             bit requiresOpBundles = 0,
                             list<int> immArgPositions = [],
-                            list<string> immArgAttrNames = []>
+                            list<string> immArgAttrNames = [],
+                            list<list<int>> opBundleOperandPositions = [],
+                            list<string> opBundleTags = []>
     : LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0,
+                  /*enumName=*/"", /*enableMlirBuilder=*/1,
                   requiresAccessGroup, requiresAliasAnalysis,
                   /*requiresFastMath=*/0, requiresOpBundles, immArgPositions,
-                  immArgAttrNames>;
+                  immArgAttrNames, opBundleOperandPositions, opBundleTags>;
 
 // Base class for LLVM intrinsic operations returning one result. Places the
 // intrinsic into the LLVM dialect and prefixes its name with "intr.". This is
@@ -437,6 +454,7 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
                            list<int> immArgPositions = [],
                            list<string> immArgAttrNames = []>
     : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
+                  /*enumName=*/"", /*enableMlirBuilder=*/1,
                   /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
                   requiresFastmath, /*requiresOpBundles=*/0, immArgPositions,
                   immArgAttrNames>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 3695708439d91f..7c204c99525ef7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -97,7 +97,7 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
   list<string> immArgAttrNames = []> :
   LLVM_IntrOpBase<ROCDL_Dialect,  mnemonic,
     "amdgcn_" # !subst(".", "_", mnemonic), overloadedResults,
-    overloadedOperands, traits, numResults, requiresAccessGroup,
+    overloadedOperands, traits, numResults, 1, requiresAccessGroup,
     requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>;
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index ffeeeae57ae952..0c7e22f8c65596 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -434,8 +434,9 @@ llvm::CallInst *createIntrinsicCall(
     llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
     Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
     ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
-    ArrayRef<unsigned> immArgPositions,
-    ArrayRef<StringLiteral> immArgAttrNames);
+    ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
+    ArrayRef<ArrayRef<unsigned>> opBundleOperandPositions,
+    ArrayRef<StringLiteral> opBundleTags);
 
 } // namespace detail
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index cc73878a64ff67..f558cf23411ed6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3441,6 +3441,24 @@ void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
                /*op_bundle_tags=*/{});
 }
 
+void LLVM::AssumeOp::build(
+    OpBuilder &builder, OperationState &state, mlir::Value cond,
+    ArrayRef<llvm::OperandBundleDefT<mlir::Value>> opBundles) {
+  SmallVector<mlir::ValueRange> opBundleOperands;
+  SmallVector<mlir::Attribute> opBundleTags;
+  opBundleOperands.reserve(opBundles.size());
+  opBundleTags.reserve(opBundles.size());
+
+  for (const llvm::OperandBundleDefT<mlir::Value> &bundle : opBundles) {
+    opBundleOperands.emplace_back(bundle.inputs());
+    opBundleTags.push_back(
+        StringAttr::get(builder.getContext(), bundle.getTag()));
+  }
+
+  auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags);
+  return build(builder, state, cond, opBundleOperands, opBundleTagsAttr);
+}
+
 LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ceb8ba3b33818b..de493891ed7e4b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -849,30 +849,34 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
     llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
     Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
     ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
-    ArrayRef<unsigned> immArgPositions,
-    ArrayRef<StringLiteral> immArgAttrNames) {
+    ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
+    ArrayRef<ArrayRef<unsigned>> opBundleOperandPositions,
+    ArrayRef<StringLiteral> opBundleTags) {
   assert(immArgPositions.size() == immArgAttrNames.size() &&
          "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
          "length");
+  assert(opBundleOperandPositions.size() == opBundleTags.size() &&
+         "operand bundles and tags do not match");
 
   SmallVector<llvm::OperandBundleDef> opBundles;
-  size_t numOpBundleOperands = 0;
+
+  size_t numVariadicOpBundleOperands = 0;
   auto opBundleSizesAttr = cast_if_present<DenseI32ArrayAttr>(
       intrOp->getAttr(LLVMDialect::getOpBundleSizesAttrName()));
   auto opBundleTagsAttr = cast_if_present<ArrayAttr>(
       intrOp->getAttr(LLVMDialect::getOpBundleTagsAttrName()));
-
   if (opBundleSizesAttr && opBundleTagsAttr) {
     ArrayRef<int> opBundleSizes = opBundleSizesAttr.asArrayRef();
     assert(opBundleSizes.size() == opBundleTagsAttr.size() &&
            "operand bundles and tags do not match");
 
-    numOpBundleOperands =
+    numVariadicOpBundleOperands =
         std::accumulate(opBundleSizes.begin(), opBundleSizes.end(), size_t(0));
-    assert(numOpBundleOperands <= intrOp->getNumOperands() &&
+    assert(numVariadicOpBundleOperands <= intrOp->getNumOperands() &&
            "operand bundle operands is more than the number of operands");
 
-    ValueRange operands = intrOp->getOperands().take_back(numOpBundleOperands);
+    ValueRange operands =
+        intrOp->getOperands().take_back(numVariadicOpBundleOperands);
     size_t nextOperandIdx = 0;
     opBundles.reserve(opBundleSizesAttr.size());
 
@@ -887,9 +891,29 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   }
 
   // Map operands and attributes to LLVM values.
-  auto opOperands = intrOp->getOperands().drop_back(numOpBundleOperands);
+  auto opOperands =
+      intrOp->getOperands().drop_back(numVariadicOpBundleOperands);
   auto operands = moduleTranslation.lookupValues(opOperands);
-  SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size());
+
+  // Map operand bundle operands to LLVM operand bundles.
+  DenseSet<unsigned> opBundleOperandPositionsSet;
+  for (auto [positions, tag] :
+       llvm::zip(opBundleOperandPositions, opBundleTags)) {
+    opBundleOperandPositionsSet.insert(positions.begin(), positions.end());
+
+    SmallVector<llvm::Value *> bundleArgs;
+    bundleArgs.reserve(positions.size());
+    for (unsigned idx : positions) {
+      assert(idx < operands.size() &&
+             "op bundle operand index is out of range");
+      bundleArgs.push_back(operands[idx]);
+    }
+
+    opBundles.emplace_back(tag.str(), std::move(bundleArgs));
+  }
+
+  SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size() -
+                                  opBundleOperandPositionsSet.size());
   for (auto [immArgPos, immArgName] :
        llvm::zip(immArgPositions, immArgAttrNames)) {
     auto attr = llvm::cast<TypedAttr>(intrOp->getAttr(immArgName));
@@ -900,8 +924,11 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   }
   unsigned opArg = 0;
   for (auto &arg : args) {
-    if (!arg)
+    if (!arg) {
+      while (opBundleOperandPositionsSet.contains(opArg))
+        ++opArg;
       arg = operands[opArg++];
+    }
   }
 
   // Resolve overloaded intrinsic declaration.
@@ -923,6 +950,8 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   llvm::Function *llvmIntr = llvm::Intrinsic::getOrInsertDeclaration(
       module, intrinsic, overloadedTypes);
 
+  llvm::outs() << "debug: createIntrinsicCall: num args = " << args.size()
+               << ", num op bundles = " << opBundles.size() << "\n";
   return builder.CreateCall(llvmIntr, args, opBundles);
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/assume.mlir b/mlir/test/Dialect/LLVMIR/assume.mlir
new file mode 100644
index 00000000000000..4cf43b4828010f
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/assume.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @assume_align
+// CHECK-SAME: (ptr %[[ARG:.+]])
+llvm.func @assume_align(%arg0: !llvm.ptr) {
+  %0 = llvm.mlir.constant(1 : i1) : i1
+  %1 = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: call void @llvm.assume(i1 true) [ "align"(ptr %[[ARG]], i32 8) ]
+  llvm.intr.assume.align %0, %arg0, %1 : (i1, !llvm.ptr, i32) -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: @assume_separate_storage
+// CHECK-SAME: (ptr %[[ARG0:.+]], ptr %[[ARG1:.+]])
+llvm.func @assume_separate_storage(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  %0 = llvm.mlir.constant(1 : i1) : i1
+  // CHECK: call void @llvm.assume(i1 true) [ "separate_storage"(ptr %[[ARG0]], ptr %[[ARG1]]) ]
+  llvm.intr.assume.separate_storage %0, %arg0, %arg1 : (i1, !llvm.ptr, !llvm.ptr) -> ()
+  llvm.return
+}
diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
index 411a98a48bfb28..6fc3e989074937 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
@@ -237,7 +237,7 @@ static bool emitIntrinsic(const Record &record, llvm::raw_ostream &os) {
   printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os);
   os << ", ";
   printBracketedRange(traits, os);
-  os << ", " << intr.getNumResults() << ", "
+  os << ", " << intr.getNumResults() << ", \"\", 1, "
      << (requiresAccessGroup ? "1" : "0") << ", "
      << (requiresAliasAnalysis ? "1" : "0") << ">, Arguments<(ins"
      << (operands.empty() ? "" : " ");

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2024

@llvm/pr-subscribers-mlir

Author: Sirui Mu (Lancern)

Changes

This PR adds two intrinsic operations, namely llvm.intr.assume.align and llvm.intr.assume.separate_storage. Module translation translates both operations to intrinsic calls to @<!-- -->llvm.assume, with different assume operand bundles.

This PR also adds a new builder to llvm.intr.assume to make it easier to build the operation with assume operand bundles.


Full diff: https://github.com/llvm/llvm-project/pull/113317.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td (+1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h (+1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td (+35-1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td (+35-17)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+1-1)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (+3-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+18)
  • (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+39-10)
  • (added) mlir/test/Dialect/LLVMIR/assume.mlir (+20)
  • (modified) mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index e81db32bcaad03..6ea3c9f2e1c7ba 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -68,6 +68,7 @@ class ArmSME_IntrOp<string mnemonic,
           /*list<int> overloadedOperands=*/overloadedOperands,
           /*list<Trait> traits=*/traits,
           /*int numResults=*/numResults,
+          /*bit enableMlirBuilder=*/1,
           /*bit requiresAccessGroup=*/0,
           /*bit requiresAliasAnalysis=*/0,
           /*bit requiresFastmath=*/0,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index d236cae0d80882..cf721f936cc932 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -33,6 +33,7 @@
 #include "mlir/Support/ThreadLocalCache.h"
 #include "llvm/ADT/PointerEmbeddedInt.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 845c88b1be7750..8bbd2b9053e160 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -450,12 +450,46 @@ def LLVM_AssumeOp
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$cond)>
+    OpBuilder<(ins "Value":$cond)>,
+    OpBuilder<(ins "Value":$cond,
+                   "ArrayRef<llvm::OperandBundleDefT<Value>>":$opBundles)>
   ];
 
   let hasVerifier = 1;
 }
 
+class LLVM_AssumeOpBase<string mnem, list<int> opBundleOperandPositions,
+                        string opBundleTag>
+    : LLVM_IntrOp<"assume." # mnem, /*overloadedResults=*/[],
+                  /*overloadedOperands=*/[], /*traits=*/[],
+                  /*numResults=*/0, /*enumName=*/"assume",
+                  /*enableMlirBuilder=*/0, /*requiresAccessGroup=*/0,
+                  /*requiresAliasAnalysis=*/0, /*requiresFastmath=*/0,
+                  /*requiresOpBundles=*/0, /*immArgPositions=*/[],
+                  /*immArgAttrNames=*/[],
+                  /*opBundleOperandPositions=*/[opBundleOperandPositions],
+                  /*opBundleTags=*/[opBundleTag]> {
+  dag args = (ins I1:$cond);
+}
+
+def LLVM_AssumeAlignOp : LLVM_AssumeOpBase<"align", [1, 2], "align"> {
+  let arguments = !con(args, (ins LLVM_AnyPointer:$ptr, AnyInteger:$align));
+
+  let assemblyFormat = [{
+    $cond `,` $ptr `,` $align attr-dict `:` functional-type(operands, results)
+  }];
+}
+
+def LLVM_AssumeSeparateStorageOp
+    : LLVM_AssumeOpBase<"separate_storage", [1, 2], "separate_storage"> {
+  let arguments = !con(
+    args, (ins LLVM_AnyPointer:$ptr1, LLVM_AnyPointer:$ptr2));
+
+  let assemblyFormat = [{
+    $cond `,` $ptr1 `,` $ptr2 attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 def LLVM_SSACopyOp : LLVM_OneResultIntrOp<"ssa.copy", [], [0],
                                             [Pure, SameOperandsAndResultType]> {
   let arguments = (ins AnyType:$operand);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index a38dafa4d9cf34..9f6acbcd3c5104 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -290,10 +290,12 @@ class LLVM_MemAccessOpBase<string mnemonic, list<Trait> traits = []> :
 class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
                       list<int> overloadedResults, list<int> overloadedOperands,
                       list<Trait> traits, int numResults,
-                      bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0,
-                      bit requiresFastmath = 0, bit requiresOpBundles = 0,
-                      list<int> immArgPositions = [],
-                      list<string> immArgAttrNames = []>
+                      bit enableMlirBuilder = 1, bit requiresAccessGroup = 0,
+                      bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
+                      bit requiresOpBundles = 0, list<int> immArgPositions = [],
+                      list<string> immArgAttrNames = [],
+                      list<list<int>> opBundleOperandPositions = [],
+                      list<string> opBundleTags = []>
     : LLVM_OpBase<dialect, opName, !listconcat(
         !if(!gt(requiresAccessGroup, 0),
             [DeclareOpInterfaceMethods<AccessGroupOpInterface>], []),
@@ -325,11 +327,18 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
   string immArgPositionsCpp = "{" # !interleave(immArgPositions, ", ") # "}";
   string immArgAttrNamesCpp = "{" # !interleave(!foreach(name, immArgAttrNames,
     "StringLiteral(\"" # name # "\")"), ", ") # "}";
+  string opBundleOperandPositionsCpp = "{" # !interleave(
+    !foreach(positions, opBundleOperandPositions,
+      "ArrayRef<unsigned>{" # !interleave(positions, ", ") # "}"
+    ), ", ") # "}";
+  string opBundleTagsCpp = "{" # !interleave(!foreach(tag, opBundleTags,
+    "StringLiteral(\"" # tag # "\")"), ", ") # "}";
   string baseLlvmBuilder = [{
     auto *inst = LLVM::detail::createIntrinsicCall(
       builder, moduleTranslation, &opInst, llvm::Intrinsic::}] # !interleave([
         enumName, "" # numResults, overloadedResultsCpp, overloadedOperandsCpp,
-        immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{);
+        immArgPositionsCpp, immArgAttrNamesCpp, opBundleOperandPositionsCpp,
+        opBundleTagsCpp], ",") # [{);
     (void) inst;
     }];
   string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", "");
@@ -357,9 +366,10 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
       $_location, resultTypes, mlirOperands, mlirAttrs);
     }];
   string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
-  let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0),
+  let mlirBuilder = !if(enableMlirBuilder,
+    baseMlirBuilder # !if(!gt(requiresFastmath, 0),
       "moduleImport.setFastmathFlagsAttr(inst, op);", "")
-    # baseMlirBuilderCoda;
+    # baseMlirBuilderCoda, "");
 
   // Code for handling a `range` attribute that holds the constant range of the
   // intrinsic's result (if one is specified at the call site). This is intended
@@ -387,16 +397,20 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
 // the intrinsic into the LLVM dialect and prefixes its name with "intr.".
 class LLVM_IntrOp<string mnem, list<int> overloadedResults,
                   list<int> overloadedOperands, list<Trait> traits,
-                  int numResults, bit requiresAccessGroup = 0,
+                  int numResults, string enumName = "",
+                  bit enableMlirBuilder = 1, bit requiresAccessGroup = 0,
                   bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
-                  bit requiresOpBundles = 0,
-                  list<int> immArgPositions = [],
-                  list<string> immArgAttrNames = []>
-    : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
+                  bit requiresOpBundles = 0, list<int> immArgPositions = [],
+                  list<string> immArgAttrNames = [],
+                  list<list<int>> opBundleOperandPositions = [],
+                  list<string> opBundleTags = []>
+    : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem,
+                      !if(!empty(enumName), !subst(".", "_", mnem), enumName),
                       overloadedResults, overloadedOperands, traits,
-                      numResults, requiresAccessGroup, requiresAliasAnalysis,
-                      requiresFastmath, requiresOpBundles, immArgPositions,
-                      immArgAttrNames>;
+                      numResults, enableMlirBuilder, requiresAccessGroup,
+                      requiresAliasAnalysis, requiresFastmath,
+                      requiresOpBundles, immArgPositions, immArgAttrNames,
+                      opBundleOperandPositions, opBundleTags>;
 
 // Base class for LLVM intrinsic operations returning no results. Places the
 // intrinsic into the LLVM dialect and prefixes its name with "intr.".
@@ -418,11 +432,14 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
                             bit requiresAliasAnalysis = 0,
                             bit requiresOpBundles = 0,
                             list<int> immArgPositions = [],
-                            list<string> immArgAttrNames = []>
+                            list<string> immArgAttrNames = [],
+                            list<list<int>> opBundleOperandPositions = [],
+                            list<string> opBundleTags = []>
     : LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0,
+                  /*enumName=*/"", /*enableMlirBuilder=*/1,
                   requiresAccessGroup, requiresAliasAnalysis,
                   /*requiresFastMath=*/0, requiresOpBundles, immArgPositions,
-                  immArgAttrNames>;
+                  immArgAttrNames, opBundleOperandPositions, opBundleTags>;
 
 // Base class for LLVM intrinsic operations returning one result. Places the
 // intrinsic into the LLVM dialect and prefixes its name with "intr.". This is
@@ -437,6 +454,7 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
                            list<int> immArgPositions = [],
                            list<string> immArgAttrNames = []>
     : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
+                  /*enumName=*/"", /*enableMlirBuilder=*/1,
                   /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
                   requiresFastmath, /*requiresOpBundles=*/0, immArgPositions,
                   immArgAttrNames>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 3695708439d91f..7c204c99525ef7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -97,7 +97,7 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
   list<string> immArgAttrNames = []> :
   LLVM_IntrOpBase<ROCDL_Dialect,  mnemonic,
     "amdgcn_" # !subst(".", "_", mnemonic), overloadedResults,
-    overloadedOperands, traits, numResults, requiresAccessGroup,
+    overloadedOperands, traits, numResults, 1, requiresAccessGroup,
     requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>;
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index ffeeeae57ae952..0c7e22f8c65596 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -434,8 +434,9 @@ llvm::CallInst *createIntrinsicCall(
     llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
     Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
     ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
-    ArrayRef<unsigned> immArgPositions,
-    ArrayRef<StringLiteral> immArgAttrNames);
+    ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
+    ArrayRef<ArrayRef<unsigned>> opBundleOperandPositions,
+    ArrayRef<StringLiteral> opBundleTags);
 
 } // namespace detail
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index cc73878a64ff67..f558cf23411ed6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3441,6 +3441,24 @@ void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
                /*op_bundle_tags=*/{});
 }
 
+void LLVM::AssumeOp::build(
+    OpBuilder &builder, OperationState &state, mlir::Value cond,
+    ArrayRef<llvm::OperandBundleDefT<mlir::Value>> opBundles) {
+  SmallVector<mlir::ValueRange> opBundleOperands;
+  SmallVector<mlir::Attribute> opBundleTags;
+  opBundleOperands.reserve(opBundles.size());
+  opBundleTags.reserve(opBundles.size());
+
+  for (const llvm::OperandBundleDefT<mlir::Value> &bundle : opBundles) {
+    opBundleOperands.emplace_back(bundle.inputs());
+    opBundleTags.push_back(
+        StringAttr::get(builder.getContext(), bundle.getTag()));
+  }
+
+  auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags);
+  return build(builder, state, cond, opBundleOperands, opBundleTagsAttr);
+}
+
 LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ceb8ba3b33818b..de493891ed7e4b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -849,30 +849,34 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
     llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
     Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
     ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
-    ArrayRef<unsigned> immArgPositions,
-    ArrayRef<StringLiteral> immArgAttrNames) {
+    ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
+    ArrayRef<ArrayRef<unsigned>> opBundleOperandPositions,
+    ArrayRef<StringLiteral> opBundleTags) {
   assert(immArgPositions.size() == immArgAttrNames.size() &&
          "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
          "length");
+  assert(opBundleOperandPositions.size() == opBundleTags.size() &&
+         "operand bundles and tags do not match");
 
   SmallVector<llvm::OperandBundleDef> opBundles;
-  size_t numOpBundleOperands = 0;
+
+  size_t numVariadicOpBundleOperands = 0;
   auto opBundleSizesAttr = cast_if_present<DenseI32ArrayAttr>(
       intrOp->getAttr(LLVMDialect::getOpBundleSizesAttrName()));
   auto opBundleTagsAttr = cast_if_present<ArrayAttr>(
       intrOp->getAttr(LLVMDialect::getOpBundleTagsAttrName()));
-
   if (opBundleSizesAttr && opBundleTagsAttr) {
     ArrayRef<int> opBundleSizes = opBundleSizesAttr.asArrayRef();
     assert(opBundleSizes.size() == opBundleTagsAttr.size() &&
            "operand bundles and tags do not match");
 
-    numOpBundleOperands =
+    numVariadicOpBundleOperands =
         std::accumulate(opBundleSizes.begin(), opBundleSizes.end(), size_t(0));
-    assert(numOpBundleOperands <= intrOp->getNumOperands() &&
+    assert(numVariadicOpBundleOperands <= intrOp->getNumOperands() &&
            "operand bundle operands is more than the number of operands");
 
-    ValueRange operands = intrOp->getOperands().take_back(numOpBundleOperands);
+    ValueRange operands =
+        intrOp->getOperands().take_back(numVariadicOpBundleOperands);
     size_t nextOperandIdx = 0;
     opBundles.reserve(opBundleSizesAttr.size());
 
@@ -887,9 +891,29 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   }
 
   // Map operands and attributes to LLVM values.
-  auto opOperands = intrOp->getOperands().drop_back(numOpBundleOperands);
+  auto opOperands =
+      intrOp->getOperands().drop_back(numVariadicOpBundleOperands);
   auto operands = moduleTranslation.lookupValues(opOperands);
-  SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size());
+
+  // Map operand bundle operands to LLVM operand bundles.
+  DenseSet<unsigned> opBundleOperandPositionsSet;
+  for (auto [positions, tag] :
+       llvm::zip(opBundleOperandPositions, opBundleTags)) {
+    opBundleOperandPositionsSet.insert(positions.begin(), positions.end());
+
+    SmallVector<llvm::Value *> bundleArgs;
+    bundleArgs.reserve(positions.size());
+    for (unsigned idx : positions) {
+      assert(idx < operands.size() &&
+             "op bundle operand index is out of range");
+      bundleArgs.push_back(operands[idx]);
+    }
+
+    opBundles.emplace_back(tag.str(), std::move(bundleArgs));
+  }
+
+  SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size() -
+                                  opBundleOperandPositionsSet.size());
   for (auto [immArgPos, immArgName] :
        llvm::zip(immArgPositions, immArgAttrNames)) {
     auto attr = llvm::cast<TypedAttr>(intrOp->getAttr(immArgName));
@@ -900,8 +924,11 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   }
   unsigned opArg = 0;
   for (auto &arg : args) {
-    if (!arg)
+    if (!arg) {
+      while (opBundleOperandPositionsSet.contains(opArg))
+        ++opArg;
       arg = operands[opArg++];
+    }
   }
 
   // Resolve overloaded intrinsic declaration.
@@ -923,6 +950,8 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   llvm::Function *llvmIntr = llvm::Intrinsic::getOrInsertDeclaration(
       module, intrinsic, overloadedTypes);
 
+  llvm::outs() << "debug: createIntrinsicCall: num args = " << args.size()
+               << ", num op bundles = " << opBundles.size() << "\n";
   return builder.CreateCall(llvmIntr, args, opBundles);
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/assume.mlir b/mlir/test/Dialect/LLVMIR/assume.mlir
new file mode 100644
index 00000000000000..4cf43b4828010f
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/assume.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @assume_align
+// CHECK-SAME: (ptr %[[ARG:.+]])
+llvm.func @assume_align(%arg0: !llvm.ptr) {
+  %0 = llvm.mlir.constant(1 : i1) : i1
+  %1 = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: call void @llvm.assume(i1 true) [ "align"(ptr %[[ARG]], i32 8) ]
+  llvm.intr.assume.align %0, %arg0, %1 : (i1, !llvm.ptr, i32) -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: @assume_separate_storage
+// CHECK-SAME: (ptr %[[ARG0:.+]], ptr %[[ARG1:.+]])
+llvm.func @assume_separate_storage(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  %0 = llvm.mlir.constant(1 : i1) : i1
+  // CHECK: call void @llvm.assume(i1 true) [ "separate_storage"(ptr %[[ARG0]], ptr %[[ARG1]]) ]
+  llvm.intr.assume.separate_storage %0, %arg0, %arg1 : (i1, !llvm.ptr, !llvm.ptr) -> ()
+  llvm.return
+}
diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
index 411a98a48bfb28..6fc3e989074937 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
@@ -237,7 +237,7 @@ static bool emitIntrinsic(const Record &record, llvm::raw_ostream &os) {
   printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os);
   os << ", ";
   printBracketedRange(traits, os);
-  os << ", " << intr.getNumResults() << ", "
+  os << ", " << intr.getNumResults() << ", \"\", 1, "
      << (requiresAccessGroup ? "1" : "0") << ", "
      << (requiresAliasAnalysis ? "1" : "0") << ">, Arguments<(ins"
      << (operands.empty() ? "" : " ");

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This intrinsic does not exist in LLVM IR right?

LLVM dialect operations usually have a 1:1 mapping to an LLVM IR instruction/intrinsic. There are a few exceptions for example for constants, since they are not instructions in LLVM IR. In these cases we always prefix the operation with llvm.mlir. to clarify this operation is special.

So if we add a new intrinsic it should be prefix the new intrinsics with llvm.intr.mlir. or similar. However, we need a good argument for this since this contradicts the LLVM dialect rational (see third paragraph of https://mlir.llvm.org/docs/Dialects/LLVM/).

Is your plan to match on these OPs somehow or would it be good enough to have a convenience builder that builds normal assume operations with the specific tags?

@ftynse what would be your take on adding specialized intrinsics that do not exist in LLVM IR (AFAIK there is no prior art)?

@joker-eph
Copy link
Collaborator

Why aren't we modeling the llvm.assume instrinsic 1:1 ? Is this because of the question of how to model the operand bundles?

@Lancern
Copy link
Member Author

Lancern commented Oct 23, 2024

Why aren't we modeling the llvm.assume instrinsic 1:1 ? Is this because of the question of how to model the operand bundles?

Yes I try to model the different assume operand bundles with dedicated operations, since I do not find a way to check whether an arbitrary operand bundle passed to llvm.intr.assume is valid or not. Is this possible? Or is it necessary to check at MLIR level that the tags passed to llvm.intr.assume are valid?

Is your plan to match on these OPs somehow or would it be good enough to have a convenience builder that builds normal assume operations with the specific tags?

I agree that having convenience builders for assume operand bundles would be great enough, since the only point of this PR is to have some degree of convenience when building assume intrinsics with tags.

@gysit
Copy link
Contributor

gysit commented Oct 23, 2024

Yes I try to model the different assume operand bundles with dedicated operations, since I do not find a way to check whether an arbitrary operand bundle passed to llvm.intr.assume is valid or not. Is this possible?

You could add a verifier to AssumeOp that checks if the operand bundles are correct. If you do that, then it is important to replicate the logic in LLVM's Verifier.cpp. At the moment, only very few intrinsics in LLVM dialect implement a verifier though. Instead, the verification happens after lowering to LLVM proper. I am fine with either of these solutions!

I agree that having convenience builders for assume operand bundles would be great enough, since the only point of this PR is to have some degree of convenience when building assume intrinsics with tags.

If the goal is to simplify the lowering, then I would go the convenience builder route and avoid introducing custom intrinsics for every operand bundle type.

Having different operations may be interesting if we want to use the assume information in transformations. However, if we want this then it would probably make sense to have a separate Assume dialect. That would require an RFC though especially since there is some overlap with existing dialects, such as memref, which implement some of this functionality.

@Lancern Lancern changed the title [mlir][LLVM] Add dedicated operations for assume align and separate_storage [mlir][LLVM] Add builders for llvm.intr.assume Oct 24, 2024
@Lancern
Copy link
Member Author

Lancern commented Oct 24, 2024

I have updated the patch and kept only the new builders for llvm.intr.assume. Specifically, the new patch adds 4 builders:

OpBuilder<(ins "Value":$cond,
               "ArrayRef<llvm::OperandBundleDefT<Value>>":$opBundles)>

OpBuilder<(ins "Value":$cond, "llvm::StringRef":$tag, "ValueRange":$args)>

OpBuilder<(ins "Value":$cond, "AssumeAlignTag":$tag, "Value":$ptr,
               "Value":$align)>

OpBuilder<(ins "Value":$cond, "AssumeSeparateStorageTag":$tag,
               "Value":$ptr1, "Value":$ptr2)>

The first two are general and the last two are for specific tags. Currently I only add builders for align and separate_storage since this is the only two tags I'm interested in. We could add more builders if people are interested in other tags.

As for tests, since we don't have any tests yet for operation builders, I assume it's safe to ignore them for now.

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM once the comments are addressed.

A low-tech alternative maybe to just define some static getters in the extraClassDeclaration that return the tag strings (e.g. static StringRef getAlignTag() / static StringRef getSeparateStorageTag()) and then have one builder that takes a StringRef and a ValueRange. However, that way there may be a mismatch between tag and the number of arguments. I am fine with both approaches!

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Lancern Lancern merged commit 93da642 into llvm:main Oct 27, 2024
8 checks passed
@Lancern Lancern deleted the mlir-llvm-assume branch October 27, 2024 03:52
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
This patch adds several new builders for llvm.intr.assume that build the
operation with additional operand bundles.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants